{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Testing Individual components of the FV HE scheme"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "from syft.frameworks.torch.he.fv.modulus import CoeffModulus\n",
    "from syft.frameworks.torch.he.fv.encryption_params import EncryptionParams\n",
    "from syft.frameworks.torch.he.fv.context import Context\n",
    "from syft.frameworks.torch.he.fv.integer_encoder import IntegerEncoder\n",
    "from syft.frameworks.torch.he.fv.key_generator import KeyGenerator\n",
    "from syft.frameworks.torch.he.fv.encryptor import Encryptor\n",
    "from syft.frameworks.torch.he.fv.decryptor import Decryptor\n",
    "from syft.frameworks.torch.he.fv.integer_encoder import IntegerEncoder\n",
    "from syft.frameworks.torch.he.fv.modulus import SeqLevelType\n",
    "from syft.frameworks.torch.he.fv.evaluator import Evaluator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Keygeneration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "poly_modulus = 64\n",
    "bit_sizes= [40]\n",
    "plain_modulus = 64\n",
    "ctx = Context(EncryptionParams(poly_modulus, CoeffModulus().create(poly_modulus, bit_sizes), plain_modulus))\n",
    "keygenerator = KeyGenerator(ctx)\n",
    "sk, pk = keygenerator.keygen()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "error",
     "ename": "AttributeError",
     "evalue": "'Context' object has no attribute 'param'",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-3-53b9ef16a468>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcoeff_modulus\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m: 'Context' object has no attribute 'param'"
     ]
    }
   ],
   "source": [
    "print(ctx.param.coeff_modulus)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "secret key values :  [[1, 1, 0, 0, 1, 1099511623296, 1099511623296, 1099511623296, 1099511623296, 1099511623296, 0, 0, 1, 0, 0, 1099511623296, 1099511623296, 0, 1099511623296, 1, 1099511623296, 1, 1099511623296, 1099511623296, 1, 1099511623296, 1099511623296, 1, 0, 0, 1099511623296, 1, 1, 0, 1, 1099511623296, 0, 1099511623296, 0, 1, 1, 1, 1, 0, 0, 1099511623296, 0, 1099511623296, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1099511623296, 1099511623296, 1, 1, 1099511623296]]\n"
    }
   ],
   "source": [
    "# print(len(sk.data))\n",
    "print('secret key values : ', sk.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(pk.data)\n",
    "# print('public key values : ', pk.data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Integer Encoder\n",
    "Encodes Integer values to Plaintext object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "[0, 1, 1]     [1, 0, 1]     [1]\n"
    }
   ],
   "source": [
    "int_encoder = IntegerEncoder(ctx)\n",
    "ri1 = random.randint(0,10)\n",
    "ri2 = random.randint(0,10)\n",
    "ri3 = random.randint(0,10)\n",
    "pt1 = int_encoder.encode(ri1)\n",
    "pt2 = int_encoder.encode(ri2)\n",
    "pt3 = int_encoder.encode(ri3)\n",
    "print(pt1.data,\"   \", pt2.data, \"   \", pt3.data)\n",
    "# print('plaintext data',plaintext.data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Decodes back to Integer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "6\n5\n1\n"
    }
   ],
   "source": [
    "print(int_encoder.decode(pt1))\n",
    "print(int_encoder.decode(pt2))\n",
    "print(int_encoder.decode(pt3))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Encrypter\n",
    "Encrypt Plaintext to ciphertext using public_key"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "encrypter = Encryptor(ctx, pk)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "ct1 = encrypter.encrypt(pt1)\n",
    "ct2 = encrypter.encrypt(pt2)\n",
    "ct3 = encrypter.encrypt(pt3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Encrypt Plaintext to ciphertext using secret_key"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Decryptor\n",
    "Decrypt Ciphertext to Plaintext using secret_key"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "decrypter = Decryptor(ctx, sk)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "dec1 = decrypter.decrypt(ct1)\n",
    "dec2 = decrypter.decrypt(ct2)\n",
    "dec3 = decrypter.decrypt(ct3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "6     5     1\n"
    }
   ],
   "source": [
    "print(int_encoder.decode(dec1), \"   \", int_encoder.decode(dec2), \"   \", int_encoder.decode(dec3))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval = Evaluator(ctx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "11\n"
    }
   ],
   "source": [
    "cc12 = eval.add(ct1, ct2)\n",
    "cc12 = decrypter.decrypt(cc12)\n",
    "print(int_encoder.decode(cc12))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "11\n"
    }
   ],
   "source": [
    "pc12 = eval.add(pt1, ct2)\n",
    "pc12 = decrypter.decrypt(pc12)\n",
    "print(int_encoder.decode(pc12))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "11\n"
    }
   ],
   "source": [
    "pp12 = eval.add(pt1, pt2)\n",
    "print(int_encoder.decode(pp12))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Verify result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert int_encoder.decode(cc12) == int_encoder.decode(pc12) == int_encoder.decode(pp12) == ri1+ri2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "\n\nct1 : [[[1067512001505, 350015162115, 311406443496, 80496350393, 524373797622, 742644556718, 62936579817, 52310147176, 385997988137, 118689484330, 498592338647, 719074861900, 1003447586295, 858667985260, 966369592686, 367520268254, 1057660243718, 880915047276, 416861344029, 108098925421, 181357590484, 618737843735, 507024383305, 897795528943, 129496199136, 481259278936, 202137923983, 453766831890, 890176989942, 806284886737, 715515666539, 690636942381, 262248549485, 221766004483, 692635831460, 607105351283, 951468695372, 586734520164, 157844984804, 924728710547, 395194259874, 438000719231, 610057978406, 458477054482, 439488609397, 1087479101999, 1091650373606, 762559739334, 1248042238, 155190016541, 103883267053, 992901830940, 354167052202, 878581966717, 904072168396, 1035942765410, 240536204934, 490619440102, 813891009974, 962058061041, 357201344497, 105272954871, 246881996360, 999577688370]], [[503483181481, 574119001787, 746005074219, 108394207121, 81257705747, 426364969775, 922883512247, 777572990385, 789493826348, 865512506928, 731724785681, 703603879152, 217197876300, 116444949996, 119136344439, 253711912198, 572421420990, 102059650924, 659990904404, 909761866886, 251489721426, 580442990765, 131033300908, 378431484307, 924618627004, 902851689199, 584447932245, 793570097415, 610147427413, 1062822188539, 1045691868484, 1020069170836, 50010874056, 14836399599, 309292706091, 418297664393, 811635219942, 994772647418, 497406013360, 667134669140, 99372557433, 704206646507, 321329237303, 260732581922, 511209635255, 130895529670, 686723430632, 482699064393, 1051755916824, 544394118055, 50069661931, 807910962107, 571856325820, 812713414632, 170590079493, 306148842665, 971747643570, 795340893542, 291869442988, 178013095182, 308973297825, 26878092339, 1003086852621, 838008105657]]]\n\n\nct2 : [[[797044452554, 90106576012, 518464916346, 205494380543, 410721316462, 737179561562, 497163997894, 422143686373, 976236938933, 857087222213, 71574364928, 275950338705, 330515544583, 667445845318, 828177875587, 143414286372, 810392243172, 459853915791, 725105698716, 431716302425, 837891457994, 334562630497, 1044738506867, 696650009648, 79651297148, 20026371725, 653395768469, 122613428735, 172997958778, 620755446164, 886212775338, 542839887607, 114030667779, 1098573584359, 540902354811, 673594868520, 1000842146349, 933264147025, 15719260261, 152525271869, 928146884715, 146728931011, 508454656576, 874137856717, 477369281335, 50534534501, 806380450208, 653052697708, 80087268273, 675925673601, 639967763400, 558033951462, 502787237599, 383902830007, 612394339793, 566637683339, 86537656788, 858127595840, 838926295606, 862696103879, 266804950093, 21883620547, 145872665970, 30827864033]], [[327949284501, 556943994185, 484302683071, 985497243568, 380665605625, 742393046623, 797051948803, 412126981324, 464428089621, 819585508921, 1018740165306, 164795545357, 1080408684019, 849499923841, 499585925997, 492185614494, 990432491283, 174392952127, 530515618534, 1028748957280, 433318622652, 946922121866, 69756658165, 664446169102, 380828099086, 845363386808, 64707590928, 13187833924, 303731848148, 876217055094, 868279024975, 1062678822174, 958680518608, 411291703682, 173134135607, 102768657745, 889858682638, 929475004881, 392237797329, 637354097520, 945668831701, 179654613603, 1046470152054, 598007509898, 295284241392, 554045294829, 844491478700, 1085130505296, 513043238162, 729212038133, 847031543906, 593063847918, 1032488230909, 792548263502, 44002608600, 564351191064, 534484064589, 463667948886, 172461882919, 802369338564, 613102182710, 1055761040048, 845824275491, 599377761876]]]\n\n\n\nfinal result:  30\n"
    }
   ],
   "source": [
    "result = eval._mul_cipher_cipher(ct1, ct2)\n",
    "print(\"\\n\\nct1 :\",ct1.data)\n",
    "print(\"\\n\\nct2 :\",ct2.data)\n",
    "print('\\n\\n')\n",
    "\n",
    "result = decrypter.decrypt(result)\n",
    "result = int_encoder.decode(result)\n",
    "\n",
    "print('final result: ', result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "30      30\n"
    }
   ],
   "source": [
    "print(ri1 * ri2, \"    \", result)\n",
    "assert ri1 * ri2 == result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Try Relinearization operation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "poly_modulus = 64\n",
    "bit_sizes = [40, 40]\n",
    "plain_modulus = 64\n",
    "ctx = Context(EncryptionParams(poly_modulus, CoeffModulus().create(poly_modulus, bit_sizes), plain_modulus))\n",
    "keygenerator = KeyGenerator(ctx)\n",
    "sk, pk = keygenerator.keygen()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "relin_keys = keygenerator.get_relin_keys()\n",
    "# relin_keys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "int_encoder = IntegerEncoder(ctx)\n",
    "a = int_encoder.encode(10)\n",
    "b = int_encoder.encode(19)\n",
    "c = int_encoder.encode(26)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "encrypter = Encryptor(ctx, pk)\n",
    "decrypter = Decryptor(ctx, sk)\n",
    "eval = Evaluator(ctx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "190\n"
    }
   ],
   "source": [
    "relin_prod_ab = eval.relin(eval.mul(encrypter.encrypt(a), encrypter.encrypt(b)), relin_keys)\n",
    "print(int_encoder.decode(decrypter.decrypt(relin_prod_ab)))\n",
    "\n",
    "assert len(relin_prod_ab.data) == 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "4940\n"
    }
   ],
   "source": [
    "relin_prod_abc = eval.relin(eval.mul(relin_prod_ab, encrypter.encrypt(c)), relin_keys)\n",
    "print(int_encoder.decode(decrypter.decrypt(relin_prod_abc)))\n",
    "\n",
    "assert len(relin_prod_abc.data) == 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "938600\n"
    }
   ],
   "source": [
    "final = eval.relin(eval.mul(relin_prod_ab, relin_prod_abc), relin_keys)\n",
    "print(int_encoder.decode(decrypter.decrypt(final)))\n",
    "\n",
    "assert len(final.data) == 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "file_extension": ".py",
  "kernelspec": {
   "display_name": "Python 3.8.5 64-bit",
   "language": "python",
   "name": "python_defaultSpec_1597056993494"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7-final"
  },
  "mimetype": "text/x-python",
  "name": "python",
  "npconvert_exporter": "python",
  "pygments_lexer": "ipython3",
  "version": 3
 },
 "nbformat": 4,
 "nbformat_minor": 2
}